home *** CD-ROM | disk | FTP | other *** search
/ Enter 2006 September / Enter 09 2006.iso / Internet / SpamExperts Home 1.1 / SpamExperts Home.exe / lib / spamexperts.modules / spambayes / storage.pyc (.txt) < prev    next >
Encoding:
Python Compiled Bytecode  |  2006-07-14  |  32.5 KB  |  1,064 lines

  1. # Source Generated with Decompyle++
  2. # File: in.pyc (Python 2.4)
  3.  
  4. '''storage.py - Spambayes database management framework.
  5.  
  6. Classes:
  7.     PickledClassifier - Classifier that uses a pickle db
  8.     DBDictClassifier - Classifier that uses a shelve db
  9.     PGClassifier - Classifier that uses postgres
  10.     mySQLClassifier - Classifier that uses mySQL
  11.     CBDClassifier - Classifier that uses CDB
  12.     ZODBClassifier - Classifier that uses ZODB
  13.     ZEOClassifier - Classifier that uses ZEO
  14.     Trainer - Classifier training observer
  15.     SpamTrainer - Trainer for spam
  16.     HamTrainer - Trainer for ham
  17.  
  18. Abstract:
  19.     *Classifier are subclasses of Classifier (classifier.Classifier)
  20.     that add automatic state store/restore function to the Classifier class.
  21.     All SQL based classifiers are subclasses of SQLClassifier, which is a
  22.     subclass of Classifier.
  23.  
  24.     PickledClassifier is a Classifier class that uses a cPickle
  25.     datastore.  This database is relatively small, but slower than other
  26.     databases.
  27.  
  28.     DBDictClassifier is a Classifier class that uses a database
  29.     store.
  30.  
  31.     Trainer is concrete class that observes a Corpus and trains a
  32.     Classifier object based upon movement of messages between corpora  When
  33.     an add message notification is received, the trainer trains the
  34.     database with the message, as spam or ham as appropriate given the
  35.     type of trainer (spam or ham).  When a remove message notification
  36.     is received, the trainer untrains the database as appropriate.
  37.  
  38.     SpamTrainer and HamTrainer are convenience subclasses of Trainer, that
  39.     initialize as the appropriate type of Trainer
  40.  
  41. To Do:
  42.     o Suggestions?
  43.  
  44.     '''
  45. __author__ = 'Neale Pickett <neale@woozle.org>, Tim Stone <tim@fourstonesExpressions.com>'
  46. __credits__ = 'All the spambayes contributors.'
  47.  
  48. try:
  49.     (True, False)
  50. except NameError:
  51.     (True, False) = (1, 0)
  52.     
  53.     def bool(val):
  54.         return not (not val)
  55.  
  56.  
  57. import os
  58. import sys
  59. import time
  60. import types
  61. from spambayes import classifier
  62. from spambayes.Options import options, get_pathname_option
  63. import cPickle as pickle
  64. import errno
  65. import shelve
  66. from spambayes import cdb
  67. from spambayes import dbmstorage
  68. oldShelvePickler = shelve.Pickler
  69.  
  70. def binaryDefaultPickler(f, binary = 1):
  71.     return oldShelvePickler(f, binary)
  72.  
  73. shelve.Pickler = binaryDefaultPickler
  74. PICKLE_TYPE = 1
  75. NO_UPDATEPROBS = False
  76. UPDATEPROBS = True
  77.  
  78. class PickledClassifier(classifier.Classifier):
  79.     '''Classifier object persisted in a pickle'''
  80.     
  81.     def __init__(self, db_name):
  82.         classifier.Classifier.__init__(self)
  83.         self.db_name = db_name
  84.         self.load()
  85.  
  86.     
  87.     def load(self):
  88.         '''Load this instance from the pickle.'''
  89.         if options[('globals', 'verbose')]:
  90.             print >>sys.stderr, 'Loading state from', self.db_name, 'pickle'
  91.         
  92.         tempbayes = None
  93.         
  94.         try:
  95.             fp = open(self.db_name, 'rb')
  96.         except IOError:
  97.             e = None
  98.             if e.errno != errno.ENOENT:
  99.                 raise 
  100.             
  101.         except:
  102.             e.errno != errno.ENOENT
  103.  
  104.         tempbayes = pickle.load(fp)
  105.         fp.close()
  106.         if tempbayes:
  107.             classifier.Classifier.__setstate__(self, tempbayes.__getstate__())
  108.             if options[('globals', 'verbose')]:
  109.                 print >>sys.stderr, '%s is an existing pickle, with %d ham and %d spam' % (self.db_name, self.nham, self.nspam)
  110.             
  111.         elif options[('globals', 'verbose')]:
  112.             print >>sys.stderr, self.db_name, 'is a new pickle'
  113.         
  114.         self.wordinfo = { }
  115.         self.nham = 0
  116.         self.nspam = 0
  117.  
  118.     
  119.     def store(self):
  120.         '''Store self as a pickle'''
  121.         if options[('globals', 'verbose')]:
  122.             print >>sys.stderr, 'Persisting', self.db_name, 'as a pickle'
  123.         
  124.         tmp = self.db_name + '.tmp'
  125.         
  126.         try:
  127.             fp = open(tmp, 'wb')
  128.             pickle.dump(self, fp, PICKLE_TYPE)
  129.             fp.close()
  130.         except IOError:
  131.             e = None
  132.             if options[('globals', 'verbose')]:
  133.                 print >>sys.stderr, 'Failed update: ' + str(e)
  134.             
  135.             if fp is not None:
  136.                 os.remove(tmp)
  137.             
  138.             raise 
  139.  
  140.         
  141.         try:
  142.             os.rename(tmp, self.db_name)
  143.         except OSError:
  144.             os.rename(self.db_name, self.db_name + '.bak')
  145.             os.rename(tmp, self.db_name)
  146.             os.remove(self.db_name + '.bak')
  147.  
  148.  
  149.     
  150.     def close(self):
  151.         pass
  152.  
  153.  
  154. WORD_DELETED = 'D'
  155. WORD_CHANGED = 'C'
  156. STATE_KEY = 'saved state'
  157.  
  158. class DBDictClassifier(classifier.Classifier):
  159.     '''Classifier object persisted in a caching database'''
  160.     
  161.     def __init__(self, db_name, mode = 'c'):
  162.         '''Constructor(database name)'''
  163.         classifier.Classifier.__init__(self)
  164.         self.statekey = STATE_KEY
  165.         self.mode = mode
  166.         self.db_name = db_name
  167.         self.load()
  168.  
  169.     
  170.     def close(self):
  171.         
  172.         def noop():
  173.             pass
  174.  
  175.         getattr(self.db, 'close', noop)()
  176.         getattr(self.dbm, 'close', noop)()
  177.         if hasattr(self, 'db'):
  178.             del self.db
  179.         
  180.         if hasattr(self, 'dbm'):
  181.             del self.dbm
  182.         
  183.         if options[('globals', 'verbose')]:
  184.             print >>sys.stderr, 'Closed', self.db_name, 'database'
  185.         
  186.  
  187.     
  188.     def load(self):
  189.         '''Load state from database'''
  190.         if options[('globals', 'verbose')]:
  191.             print >>sys.stderr, 'Loading state from', self.db_name, 'database'
  192.         
  193.         self.dbm = dbmstorage.open(self.db_name, self.mode)
  194.         self.db = shelve.Shelf(self.dbm)
  195.         if self.db.has_key(self.statekey):
  196.             t = self.db[self.statekey]
  197.             if t[0] != classifier.PICKLE_VERSION:
  198.                 raise ValueError("Can't unpickle -- version %s unknown" % t[0])
  199.             
  200.             (self.nspam, self.nham) = t[1:]
  201.             if options[('globals', 'verbose')]:
  202.                 print >>sys.stderr, '%s is an existing database, with %d spam and %d ham' % (self.db_name, self.nspam, self.nham)
  203.             
  204.         elif options[('globals', 'verbose')]:
  205.             print >>sys.stderr, self.db_name, 'is a new database'
  206.         
  207.         self.nspam = 0
  208.         self.nham = 0
  209.         self.wordinfo = { }
  210.         self.changed_words = { }
  211.  
  212.     
  213.     def store(self):
  214.         '''Place state into persistent store'''
  215.         if options[('globals', 'verbose')]:
  216.             print >>sys.stderr, 'Persisting', self.db_name, 'state in database'
  217.         
  218.         for key, flag in self.changed_words.iteritems():
  219.             if flag is WORD_CHANGED:
  220.                 val = self.wordinfo[key]
  221.                 self.db[key] = val.__getstate__()
  222.                 continue
  223.             if flag is WORD_DELETED:
  224.                 if not key not in self.wordinfo:
  225.                     raise AssertionError, 'Should not have a wordinfo for words flagged for delete'
  226.                 
  227.                 try:
  228.                     del self.db[key]
  229.                 except KeyError:
  230.                     pass
  231.                 except:
  232.                     None<EXCEPTION MATCH>KeyError
  233.                 
  234.  
  235.             None<EXCEPTION MATCH>KeyError
  236.             raise RuntimeError, 'Unknown flag value'
  237.         
  238.         self.changed_words = { }
  239.         self._write_state_key()
  240.         self.db.sync()
  241.  
  242.     
  243.     def _write_state_key(self):
  244.         self.db[self.statekey] = (classifier.PICKLE_VERSION, self.nspam, self.nham)
  245.  
  246.     
  247.     def _post_training(self):
  248.         '''This is called after training on a wordstream.  We ensure that the
  249.         database is in a consistent state at this point by writing the state
  250.         key.'''
  251.         self._write_state_key()
  252.  
  253.     
  254.     def _wordinfoget(self, word):
  255.         if isinstance(word, unicode):
  256.             word = word.encode('utf-8')
  257.         
  258.         
  259.         try:
  260.             return self.wordinfo[word]
  261.         except KeyError:
  262.             ret = None
  263.             if self.changed_words.get(word) is not WORD_DELETED:
  264.                 r = self.db.get(word)
  265.                 if r:
  266.                     ret = self.WordInfoClass()
  267.                     ret.__setstate__(r)
  268.                     self.wordinfo[word] = ret
  269.                 
  270.             
  271.             return ret
  272.  
  273.  
  274.     
  275.     def _wordinfoset(self, word, record):
  276.         if isinstance(word, unicode):
  277.             word = word.encode('utf-8')
  278.         
  279.         if record.spamcount + record.hamcount <= 1:
  280.             self.db[word] = record.__getstate__()
  281.             
  282.             try:
  283.                 del self.changed_words[word]
  284.             except KeyError:
  285.                 pass
  286.  
  287.             
  288.             try:
  289.                 del self.wordinfo[word]
  290.             except KeyError:
  291.                 pass
  292.             except:
  293.                 None<EXCEPTION MATCH>KeyError
  294.             
  295.  
  296.         None<EXCEPTION MATCH>KeyError
  297.         self.wordinfo[word] = record
  298.         self.changed_words[word] = WORD_CHANGED
  299.  
  300.     
  301.     def _wordinfodel(self, word):
  302.         if isinstance(word, unicode):
  303.             word = word.encode('utf-8')
  304.         
  305.         del self.wordinfo[word]
  306.         self.changed_words[word] = WORD_DELETED
  307.  
  308.     
  309.     def _wordinfokeys(self):
  310.         wordinfokeys = self.db.keys()
  311.         del wordinfokeys[wordinfokeys.index(self.statekey)]
  312.         return wordinfokeys
  313.  
  314.  
  315.  
  316. class SQLClassifier(classifier.Classifier):
  317.     
  318.     def __init__(self, db_name):
  319.         '''Constructor(database name)'''
  320.         classifier.Classifier.__init__(self)
  321.         self.statekey = STATE_KEY
  322.         self.db_name = db_name
  323.         self.load()
  324.  
  325.     
  326.     def close(self):
  327.         '''Release all database resources'''
  328.         pass
  329.  
  330.     
  331.     def load(self):
  332.         '''Load state from the database'''
  333.         raise NotImplementedError, 'must be implemented in subclass'
  334.  
  335.     
  336.     def store(self):
  337.         '''Save state to the database'''
  338.         self._set_row(self.statekey, self.nspam, self.nham)
  339.  
  340.     
  341.     def cursor(self):
  342.         '''Return a new db cursor'''
  343.         raise NotImplementedError, 'must be implemented in subclass'
  344.  
  345.     
  346.     def fetchall(self, c):
  347.         '''Return all rows as a dict'''
  348.         raise NotImplementedError, 'must be implemented in subclass'
  349.  
  350.     
  351.     def commit(self, c):
  352.         '''Commit the current transaction - may commit at db or cursor'''
  353.         raise NotImplementedError, 'must be implemented in subclass'
  354.  
  355.     
  356.     def create_bayes(self):
  357.         '''Create a new bayes table'''
  358.         c = self.cursor()
  359.         c.execute(self.table_definition)
  360.         self.commit(c)
  361.  
  362.     
  363.     def _get_row(self, word):
  364.         '''Return row matching word'''
  365.         
  366.         try:
  367.             c = self.cursor()
  368.             c.execute('select * from bayes  where word=%s', (word,))
  369.         except Exception:
  370.             e = None
  371.             print >>sys.stderr, 'error:', (e, word)
  372.             raise 
  373.  
  374.         rows = self.fetchall(c)
  375.         if rows:
  376.             return rows[0]
  377.         else:
  378.             return { }
  379.  
  380.     
  381.     def _set_row(self, word, nspam, nham):
  382.         c = self.cursor()
  383.         if self._has_key(word):
  384.             c.execute('update bayes  set nspam=%s,nham=%s  where word=%s', (nspam, nham, word))
  385.         else:
  386.             c.execute('insert into bayes  (nspam, nham, word)  values (%s, %s, %s)', (nspam, nham, word))
  387.         self.commit(c)
  388.  
  389.     
  390.     def _delete_row(self, word):
  391.         c = self.cursor()
  392.         c.execute('delete from bayes  where word=%s', (word,))
  393.         self.commit(c)
  394.  
  395.     
  396.     def _has_key(self, key):
  397.         c = self.cursor()
  398.         c.execute('select word from bayes  where word=%s', (key,))
  399.         return len(self.fetchall(c)) > 0
  400.  
  401.     
  402.     def _wordinfoget(self, word):
  403.         if isinstance(word, unicode):
  404.             word = word.encode('utf-8')
  405.         
  406.         row = self._get_row(word)
  407.         if row:
  408.             item = self.WordInfoClass()
  409.             item.__setstate__((row['nspam'], row['nham']))
  410.             return item
  411.         else:
  412.             return self.WordInfoClass()
  413.  
  414.     
  415.     def _wordinfoset(self, word, record):
  416.         if isinstance(word, unicode):
  417.             word = word.encode('utf-8')
  418.         
  419.         self._set_row(word, record.spamcount, record.hamcount)
  420.  
  421.     
  422.     def _wordinfodel(self, word):
  423.         if isinstance(word, unicode):
  424.             word = word.encode('utf-8')
  425.         
  426.         self._delete_row(word)
  427.  
  428.     
  429.     def _wordinfokeys(self):
  430.         c = self.cursor()
  431.         c.execute('select word from bayes')
  432.         rows = self.fetchall(c)
  433.         return [ r[0] for r in rows ]
  434.  
  435.  
  436.  
  437. class PGClassifier(SQLClassifier):
  438.     '''Classifier object persisted in a Postgres database'''
  439.     
  440.     def __init__(self, db_name):
  441.         self.table_definition = "create table bayes (  word bytea not null default '',  nspam integer not null default 0,  nham integer not null default 0,  primary key(word))"
  442.         SQLClassifier.__init__(self, db_name)
  443.  
  444.     
  445.     def cursor(self):
  446.         return self.db.cursor()
  447.  
  448.     
  449.     def fetchall(self, c):
  450.         return c.dictfetchall()
  451.  
  452.     
  453.     def commit(self, c):
  454.         self.db.commit()
  455.  
  456.     
  457.     def load(self):
  458.         '''Load state from database'''
  459.         import psycopg
  460.         if options[('globals', 'verbose')]:
  461.             print >>sys.stderr, 'Loading state from', self.db_name, 'database'
  462.         
  463.         self.db = psycopg.connect(self.db_name)
  464.         c = self.cursor()
  465.         
  466.         try:
  467.             c.execute('select count(*) from bayes')
  468.         except psycopg.ProgrammingError:
  469.             self.db.rollback()
  470.             self.create_bayes()
  471.  
  472.         if self._has_key(self.statekey):
  473.             row = self._get_row(self.statekey)
  474.             self.nspam = row['nspam']
  475.             self.nham = row['nham']
  476.             if options[('globals', 'verbose')]:
  477.                 print >>sys.stderr, '%s is an existing database, with %d spam and %d ham' % (self.db_name, self.nspam, self.nham)
  478.             
  479.         elif options[('globals', 'verbose')]:
  480.             print >>sys.stderr, self.db_name, 'is a new database'
  481.         
  482.         self.nspam = 0
  483.         self.nham = 0
  484.  
  485.  
  486.  
  487. class mySQLClassifier(SQLClassifier):
  488.     '''Classifier object persisted in a mySQL database
  489.  
  490.     It is assumed that the database already exists, and that the mySQL
  491.     server is currently running.'''
  492.     
  493.     def __init__(self, data_source_name):
  494.         self.table_definition = "create table bayes (  word varchar(255) not null default '',  nspam integer not null default 0,  nham integer not null default 0,  primary key(word));"
  495.         self.host = 'localhost'
  496.         self.username = 'root'
  497.         self.password = ''
  498.         db_name = 'spambayes'
  499.         source_info = data_source_name.split()
  500.         for info in source_info:
  501.             if info.startswith('host'):
  502.                 self.host = info[5:]
  503.                 continue
  504.             if info.startswith('user'):
  505.                 self.username = info[5:]
  506.                 continue
  507.             if info.startswith('pass'):
  508.                 self.username = info[5:]
  509.                 continue
  510.             if info.startswith('dbname'):
  511.                 db_name = info[7:]
  512.                 continue
  513.         
  514.         SQLClassifier.__init__(self, db_name)
  515.  
  516.     
  517.     def cursor(self):
  518.         return self.db.cursor()
  519.  
  520.     
  521.     def fetchall(self, c):
  522.         return c.fetchall()
  523.  
  524.     
  525.     def commit(self, c):
  526.         self.db.commit()
  527.  
  528.     
  529.     def load(self):
  530.         '''Load state from database'''
  531.         import MySQLdb
  532.         if options[('globals', 'verbose')]:
  533.             print >>sys.stderr, 'Loading state from', self.db_name, 'database'
  534.         
  535.         self.db = MySQLdb.connect(host = self.host, db = self.db_name, user = self.username, passwd = self.password)
  536.         c = self.cursor()
  537.         
  538.         try:
  539.             c.execute('select count(*) from bayes')
  540.         except MySQLdb.ProgrammingError:
  541.             
  542.             try:
  543.                 self.db.rollback()
  544.             except MySQLdb.NotSupportedError:
  545.                 pass
  546.  
  547.             self.create_bayes()
  548.  
  549.         if self._has_key(self.statekey):
  550.             row = self._get_row(self.statekey)
  551.             self.nspam = int(row[1])
  552.             self.nham = int(row[2])
  553.             if options[('globals', 'verbose')]:
  554.                 print >>sys.stderr, '%s is an existing database, with %d spam and %d ham' % (self.db_name, self.nspam, self.nham)
  555.             
  556.         elif options[('globals', 'verbose')]:
  557.             print >>sys.stderr, self.db_name, 'is a new database'
  558.         
  559.         self.nspam = 0
  560.         self.nham = 0
  561.  
  562.     
  563.     def _wordinfoget(self, word):
  564.         if isinstance(word, unicode):
  565.             word = word.encode('utf-8')
  566.         
  567.         row = self._get_row(word)
  568.         if row:
  569.             item = self.WordInfoClass()
  570.             item.__setstate__((row[1], row[2]))
  571.             return item
  572.         else:
  573.             return None
  574.  
  575.  
  576.  
  577. class CDBClassifier(classifier.Classifier):
  578.     '''A classifier that uses a CDB database.
  579.  
  580.     A CDB wordinfo database is quite small and fast but is slow to update.
  581.     It is appropriate if training is done rarely (e.g. monthly or weekly
  582.     using archived ham and spam).
  583.     '''
  584.     
  585.     def __init__(self, db_name):
  586.         classifier.Classifier.__init__(self)
  587.         self.db_name = db_name
  588.         self.statekey = STATE_KEY
  589.         self.load()
  590.  
  591.     
  592.     def _WordInfoFactory(self, counts):
  593.         (ham, spam) = counts.split(',')
  594.         wi = classifier.WordInfo()
  595.         wi.hamcount = int(ham)
  596.         wi.spamcount = int(spam)
  597.         return wi
  598.  
  599.     
  600.     def uunquote(self, s):
  601.         for encoding in ('utf-8', 'cp1252', 'iso-8859-1'):
  602.             
  603.             try:
  604.                 return unicode(s, encoding)
  605.             continue
  606.             except UnicodeDecodeError:
  607.                 continue
  608.             
  609.  
  610.         
  611.         return s
  612.  
  613.     
  614.     def load(self):
  615.         if os.path.exists(self.db_name):
  616.             db = open(self.db_name, 'rb')
  617.             data = dict(cdb.Cdb(db))
  618.             db.close()
  619.             (self.nham, self.nspam) = [ int(i) for i in data[self.statekey].split(',') ]
  620.             self.wordinfo = [](_[1])
  621.             if options[('globals', 'verbose')]:
  622.                 print >>sys.stderr, '%s is an existing CDB, with %d ham and %d spam' % (self.db_name, self.nham, self.nspam)
  623.             
  624.         elif options[('globals', 'verbose')]:
  625.             print >>sys.stderr, self.db_name, 'is a new CDB'
  626.         
  627.         self.wordinfo = { }
  628.         self.nham = 0
  629.         self.nspam = 0
  630.  
  631.     
  632.     def store(self):
  633.         items = [
  634.             (self.statekey, '%d,%d' % (self.nham, self.nspam))]
  635.         for word, wi in self.wordinfo.iteritems():
  636.             if isinstance(word, types.UnicodeType):
  637.                 word = word.encode('utf-8')
  638.             
  639.             items.append((word, '%d,%d' % (wi.hamcount, wi.spamcount)))
  640.         
  641.         db = open(self.db_name, 'wb')
  642.         cdb.cdb_make(db, items)
  643.         db.close()
  644.  
  645.     
  646.     def close(self):
  647.         pass
  648.  
  649.  
  650.  
  651. try:
  652.     from persistent import Persistent
  653. except ImportError:
  654.     
  655.     try:
  656.         from ZODB import Persistent
  657.     except ImportError:
  658.         Persistent = object
  659.     except:
  660.         None<EXCEPTION MATCH>ImportError
  661.     
  662.  
  663.     None<EXCEPTION MATCH>ImportError
  664.  
  665.  
  666. class _PersistentClassifier(classifier.Classifier, Persistent):
  667.     
  668.     def __init__(self):
  669.         import ZODB
  670.         OOBTree = OOBTree
  671.         import BTrees.OOBTree
  672.         classifier.Classifier.__init__(self)
  673.         self.wordinfo = OOBTree()
  674.  
  675.  
  676.  
  677. class ZODBClassifier(object):
  678.     ClassifierClass = _PersistentClassifier
  679.     
  680.     def __init__(self, db_name, mode = 'c'):
  681.         self.db_filename = db_name
  682.         self.db_name = os.path.basename(db_name)
  683.         self.closed = True
  684.         self.mode = mode
  685.         self.load()
  686.  
  687.     
  688.     def __getattr__(self, att):
  689.         if hasattr(self, 'classifier') and hasattr(self.classifier, att):
  690.             return getattr(self.classifier, att)
  691.         
  692.         raise AttributeError("ZODBClassifier object has no attribute '%s'" % (att,))
  693.  
  694.     
  695.     def __setattr__(self, att, value):
  696.         if att in ('nham', 'nspam') and hasattr(self, 'classifier'):
  697.             setattr(self.classifier, att, value)
  698.         else:
  699.             object.__setattr__(self, att, value)
  700.  
  701.     
  702.     def create_storage(self):
  703.         import ZODB
  704.         FileStorage = FileStorage
  705.         import ZODB.FileStorage
  706.         self.storage = FileStorage(self.db_filename, read_only = self.mode == 'r')
  707.  
  708.     
  709.     def load(self):
  710.         '''Load state from database'''
  711.         import ZODB
  712.         if options[('globals', 'verbose')]:
  713.             print >>sys.stderr, 'Loading state from %s (%s) database' % (self.db_filename, self.db_name)
  714.         
  715.         if not self.closed:
  716.             self.close()
  717.         
  718.         self.create_storage()
  719.         self.DB = ZODB.DB(self.storage)
  720.         self.conn = self.DB.open()
  721.         root = self.conn.root()
  722.         self.classifier = root.get(self.db_name)
  723.         if self.classifier is None:
  724.             if options[('globals', 'verbose')]:
  725.                 print >>sys.stderr, self.db_name, 'is a new ZODB'
  726.             
  727.             self.classifier = root[self.db_name] = self.ClassifierClass()
  728.         elif options[('globals', 'verbose')]:
  729.             print >>sys.stderr, '%s is an existing ZODB, with %d ham and %d spam' % (self.db_name, self.nham, self.nspam)
  730.         
  731.         self.closed = False
  732.  
  733.     
  734.     def store(self):
  735.         '''Place state into persistent store'''
  736.         
  737.         try:
  738.             import ZODB
  739.             import ZODB.Transaction as ZODB
  740.         except ImportError:
  741.             import transaction
  742.             commit = transaction.commit
  743.             abort = transaction.abort
  744.  
  745.         commit = ZODB.Transaction.get_transaction().commit
  746.         abort = ZODB.Transaction.get_transaction().abort
  747.         ConflictError = ConflictError
  748.         import ZODB.POSException
  749.         ReadOnlyError = ReadOnlyError
  750.         import ZODB.POSException
  751.         TransactionFailedError = TransactionFailedError
  752.         import ZODB.POSException
  753.         if not self.closed == False:
  754.             raise AssertionError, "Can't store a closed database"
  755.         if options[('globals', 'verbose')]:
  756.             print >>sys.stderr, 'Persisting', self.db_name, 'state in database'
  757.         
  758.         
  759.         try:
  760.             commit()
  761.         except ConflictError:
  762.             if options[('globals', 'verbose')]:
  763.                 print >>sys.stderr, 'Conflict on commit', self.db_name
  764.             
  765.             abort()
  766.         except TransactionFailedError:
  767.             print >>sys.stderr, 'Storing failed.  Need to restart.', self.db_name
  768.             abort()
  769.         except ReadOnlyError:
  770.             print >>sys.stderr, "Can't store transaction to read-only db."
  771.             abort()
  772.  
  773.  
  774.     
  775.     def close(self):
  776.         if self.mode != 'r':
  777.             self.store()
  778.         
  779.         self.DB.close()
  780.         if self.mode != 'r' and hasattr(self.storage, 'pack'):
  781.             self.storage.pack(time.time() - 60 * 60 * 24, None)
  782.         
  783.         self.storage.close()
  784.         delattr(self, 'classifier')
  785.         self.closed = True
  786.         if options[('globals', 'verbose')]:
  787.             print >>sys.stderr, 'Closed', self.db_name, 'database'
  788.         
  789.  
  790.  
  791.  
  792. class ZEOClassifier(ZODBClassifier):
  793.     
  794.     def __init__(self, data_source_name):
  795.         source_info = data_source_name.split()
  796.         self.host = 'localhost'
  797.         self.port = None
  798.         db_name = 'SpamBayes'
  799.         for info in source_info:
  800.             if info.startswith('host'):
  801.                 self.host = info[5:]
  802.                 continue
  803.             if info.startswith('port'):
  804.                 self.port = int(info[5:])
  805.                 continue
  806.             if info.startswith('dbname'):
  807.                 db_name = info[7:]
  808.                 continue
  809.         
  810.         ZODBClassifier.__init__(self, db_name)
  811.  
  812.     
  813.     def create_storage(self):
  814.         ClientStorage = ClientStorage
  815.         import ZEO.ClientStorage
  816.         if self.port:
  817.             addr = (self.host, self.port)
  818.         else:
  819.             addr = self.host
  820.         self.storage = ClientStorage(addr)
  821.  
  822.  
  823. NO_TRAINING_FLAG = 1
  824.  
  825. class Trainer:
  826.     '''Associates a Classifier object and one or more Corpora,     is an observer of the corpora'''
  827.     
  828.     def __init__(self, bayes, is_spam, updateprobs = NO_UPDATEPROBS):
  829.         '''Constructor(Classifier, is_spam(True|False), updprobs(True|False)'''
  830.         self.bayes = bayes
  831.         self.is_spam = is_spam
  832.         self.updateprobs = updateprobs
  833.  
  834.     
  835.     def onAddMessage(self, message, flags = 0):
  836.         '''A message is being added to an observed corpus.'''
  837.         if not flags & NO_TRAINING_FLAG:
  838.             self.train(message)
  839.         
  840.  
  841.     
  842.     def train(self, message):
  843.         '''Train the database with the message'''
  844.         if options[('globals', 'verbose')]:
  845.             print >>sys.stderr, 'training with', message.key()
  846.         
  847.         self.bayes.learn(message.tokenize(), self.is_spam)
  848.         message.setId(message.key())
  849.         message.RememberTrained(self.is_spam)
  850.  
  851.     
  852.     def onRemoveMessage(self, message, flags = 0):
  853.         '''A message is being removed from an observed corpus.'''
  854.         if not flags & NO_TRAINING_FLAG:
  855.             self.untrain(message)
  856.         
  857.  
  858.     
  859.     def untrain(self, message):
  860.         '''Untrain the database with the message'''
  861.         if options[('globals', 'verbose')]:
  862.             print >>sys.stderr, 'untraining with', message.key()
  863.         
  864.         self.bayes.unlearn(message.tokenize(), self.is_spam)
  865.         message.RememberTrained(None)
  866.  
  867.     
  868.     def trainAll(self, corpus):
  869.         '''Train all the messages in the corpus'''
  870.         for msg in corpus:
  871.             self.train(msg)
  872.         
  873.  
  874.     
  875.     def untrainAll(self, corpus):
  876.         '''Untrain all the messages in the corpus'''
  877.         for msg in corpus:
  878.             self.untrain(msg)
  879.         
  880.  
  881.  
  882.  
  883. class SpamTrainer(Trainer):
  884.     '''Trainer for spam'''
  885.     
  886.     def __init__(self, bayes, updateprobs = NO_UPDATEPROBS):
  887.         '''Constructor'''
  888.         Trainer.__init__(self, bayes, True, updateprobs)
  889.  
  890.  
  891.  
  892. class HamTrainer(Trainer):
  893.     '''Trainer for ham'''
  894.     
  895.     def __init__(self, bayes, updateprobs = NO_UPDATEPROBS):
  896.         '''Constructor'''
  897.         Trainer.__init__(self, bayes, False, updateprobs)
  898.  
  899.  
  900.  
  901. class NoSuchClassifierError(Exception):
  902.     
  903.     def __init__(self, invalid_name):
  904.         self.invalid_name = invalid_name
  905.  
  906.     
  907.     def __str__(self):
  908.         return repr(self.invalid_name)
  909.  
  910.  
  911.  
  912. class MutuallyExclusiveError(Exception):
  913.     
  914.     def __str__(self):
  915.         return 'Only one type of database can be specified'
  916.  
  917.  
  918. _storage_types = {
  919.     'dbm': (DBDictClassifier, True, True),
  920.     'pickle': (PickledClassifier, False, True),
  921.     'pgsql': (PGClassifier, False, False),
  922.     'mysql': (mySQLClassifier, False, False),
  923.     'cdb': (CDBClassifier, False, True),
  924.     'zodb': (ZODBClassifier, True, True),
  925.     'zeo': (ZEOClassifier, False, False) }
  926.  
  927. def open_storage(data_source_name, db_type = 'dbm', mode = None):
  928.     '''Return a storage object appropriate to the given parameters.
  929.  
  930.     By centralizing this code here, all the applications will behave
  931.     the same given the same options.
  932.     '''
  933.     
  934.     try:
  935.         (klass, supports_mode, unused) = _storage_types[db_type]
  936.     except KeyError:
  937.         raise NoSuchClassifierError(db_type)
  938.  
  939.     
  940.     try:
  941.         if supports_mode and mode is not None:
  942.             return klass(data_source_name, mode)
  943.         else:
  944.             return klass(data_source_name)
  945.     except dbmstorage.error:
  946.         e = None
  947.         if str(e) == 'No dbm modules available!':
  948.             print >>sys.stderr, '\nYou do not have a dbm module available to use.  You need to either use a pickle (see the FAQ), use Python 2.3 (or above), or install a dbm module such as bsddb (see http://sf.net/projects/pybsddb).'
  949.             sys.exit()
  950.         
  951.         raise 
  952.  
  953.  
  954. _storage_options = {
  955.     '-p': 'pickle',
  956.     '-d': 'dbm' }
  957.  
  958. def database_type(opts, default_type = ('Storage', 'persistent_use_database'), default_name = ('Storage', 'persistent_storage_file')):
  959.     '''Return the name of the database and the type to use.  The output of
  960.     this function can be used as the db_type parameter for the open_storage
  961.     function, for example:
  962.  
  963.         [standard getopts code]
  964.         db_name, db_type = database_type(opts)
  965.         storage = open_storage(db_name, db_type)
  966.  
  967.     The selection is made based on the options passed, or, if the
  968.     appropriate options are not present, the options in the global
  969.     options object.
  970.  
  971.     Currently supports:
  972.        -p  :  pickle
  973.        -d  :  dbm
  974.     '''
  975.     (nm, typ) = (None, None)
  976.     for opt, arg in opts:
  977.         if _storage_options.has_key(opt):
  978.             if nm is None and typ is None:
  979.                 nm = arg
  980.                 typ = _storage_options[opt]
  981.             else:
  982.                 raise MutuallyExclusiveError()
  983.         typ is None
  984.     
  985.     if nm is None and typ is None:
  986.         typ = options[default_type]
  987.         
  988.         try:
  989.             (unused, unused, is_path) = _storage_types[typ]
  990.         except KeyError:
  991.             raise NoSuchClassifierError(db_type)
  992.  
  993.         if is_path:
  994.             nm = get_pathname_option(*default_name)
  995.         else:
  996.             nm = options[default_name]
  997.     
  998.     return (nm, typ)
  999.  
  1000.  
  1001. def convert(old_name = None, old_type = None, new_name = None, new_type = None):
  1002.     if old_name is None:
  1003.         old_name = 'hammie.db'
  1004.     
  1005.     if old_type is None:
  1006.         old_type = 'dbm'
  1007.     
  1008.     if new_name is None or new_type is None:
  1009.         (auto_name, auto_type) = database_type({ })
  1010.         if new_name is None:
  1011.             new_name = auto_name
  1012.         
  1013.         if new_type is None:
  1014.             new_type = auto_type
  1015.         
  1016.     
  1017.     old_bayes = open_storage(old_name, old_type, 'r')
  1018.     new_bayes = open_storage(new_name, new_type)
  1019.     words = old_bayes._wordinfokeys()
  1020.     
  1021.     try:
  1022.         new_bayes.nham = old_bayes.nham
  1023.     except AttributeError:
  1024.         new_bayes.nham = 0
  1025.  
  1026.     
  1027.     try:
  1028.         new_bayes.nspam = old_bayes.nspam
  1029.     except AttributeError:
  1030.         new_bayes.nspam = 0
  1031.  
  1032.     print >>sys.stderr, 'Converting %s (%s database) to %s (%s database).' % (old_name, old_type, new_name, new_type)
  1033.     print >>sys.stderr, 'Database has %s ham, %s spam, and %s words.' % (new_bayes.nham, new_bayes.nspam, len(words))
  1034.     for word in words:
  1035.         new_bayes._wordinfoset(word, old_bayes._wordinfoget(word))
  1036.     
  1037.     old_bayes.close()
  1038.     print >>sys.stderr, 'Storing database, please be patient...'
  1039.     new_bayes.store()
  1040.     print >>sys.stderr, 'Conversion complete.'
  1041.     new_bayes.close()
  1042.  
  1043.  
  1044. def ensureDir(dirname):
  1045.     '''Ensure that the given directory exists - in other words, if it
  1046.     does not exist, attempt to create it.'''
  1047.     
  1048.     try:
  1049.         os.mkdir(dirname)
  1050.         if options[('globals', 'verbose')]:
  1051.             print >>sys.stderr, 'Creating directory', dirname
  1052.     except OSError:
  1053.         e = None
  1054.         if e.errno != errno.EEXIST:
  1055.             raise 
  1056.         
  1057.     except:
  1058.         e.errno != errno.EEXIST
  1059.  
  1060.  
  1061. if __name__ == '__main__':
  1062.     print >>sys.stderr, __doc__
  1063.  
  1064.